-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[llama] Store KV Cache on CPU and Use PyTorch SPDA
for Next token generation
#1182
base: main
Are you sure you want to change the base?
Conversation
@hshen14 @luoyu-intel for awareness |
@mandy-li @libinta @dvarshney-habana This is the first PR of system optimization from intel neural compressor(INC) team, could you give a review? Experiments of Llama2 on single Gaudi2 card with Xeon 8380 host. With offloading KV Cache and SDPA to CPU, we improve the context limit from 26k(input:10k+output:16k) to 310k(input:10k+output:300k).
|
Please sync your PR with main/upstream and fix any merge conflicts. Thank you. |
done. |
@zhentaoyu this PR also has merge conflict with main, could you please take a look at the differences? |
|
Hi, @imangohari1, I have updated the PR (see descriptions). Could you please retake a look when you have free time? Please let me know if you have more comments or need more tests. Thanks a lot. |
else: | ||
unwrap_deepspeed_model(self).allocate_kv_cache( | ||
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From line 1096 to 1107, I would like to suggest to change like this.
if not is_greedy_or_beam_and_bucket:
cache_device = "hpu"
if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
print("Allocate KV Cache on CPU...")
cache_device = "cpu"
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
device=cache_device
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I have updated it in 74e94ff. However, I can not remove the else line because I only modified the modeling_llama.py
for this experimental feature.
@zhentaoyu Do you have a use case for "It's an option for long-context inference or generation when a single hpu card OOM." ? |
Signed-off-by: Yu Zhentao <[email protected]>
Signed-off-by: Yu Zhentao <[email protected]>
Signed-off-by: Yu Zhentao <[email protected]>
Signed-off-by: Yu Zhentao <[email protected]>
Hi @yeonsily, thanks for your comment. Yes, I add a case in README and update the results in the PR description. |
else: | ||
with ht.sdp_kernel(enable_recompute=flash_attention_recompute): | ||
else: | ||
if kv_cache_on_host: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please explain what's the case switching kv_cache device? I thought line 656 is the case only when line 658.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this pr, we make kv cache store on cpu and do cpu sdpa only when generating the next token. The first token or prefill stage is performed on HPU due to its powerful computation ability under long-context scenario (long prompt in most cases). The full pipeline diagram shows on the pr description.
So line 658 tells the machine it can do pytorch-cpu sdpa (flash-attn) only when kv_cache_on_host & in next-token generation & inference stage. Otherwise, it will transfer the kv-cache to hpu device if need for its original operations.
Please let me know if you need more explanations or have some suggestions. Thanks.
@yeonsily the similiar features already available in tensorrt-llm https://nvidia.github.io/TensorRT-LLM/kv_cache_reuse.html#offloading-to-host-memory |
What does this PR do?
Results
python run_generation.py --model_name_or_path meta-llama/Llama-2-7b-hf --max_new_tokens 4096 --bf16 --use_kv_cache --attn_softmax_bf16 --reuse_cache --do_sample --prompt "Tell me somethings about Intel"
--kv_cache_on_host
```bash Stats: -------------------------------------------------------------------------------------------------------------- Throughput (including tokenization) = 2.132539697795915 tokens/second Number of HPU graphs = 14 Memory allocated = 12.68 GB Max memory allocated = 12.77 GB Total memory available = 94.62 GB Graph compilation duration = 5842.699780527037 seconds~~ -------------------------------------------------------------------------------------------------------------- ```update 4b0fa1a
--kv_cache_on_host
Limitations
--use_hpu_graphs
because it has host-device memory transfer in the self-attn forward process.cc @airMeng and @luoyu-intel
Update
Yi-34b-chat
on gaudi-2 with ~11k input + 5k outputcommand:
python run_generation.py \ --model_name_or_path 01-ai/Yi-34B-Chat \ --use_kv_cache \ --bf16 \ --attn_softmax_bf16 \ --reuse_cache \ --do_sample \ --dataset_name emozilla/pg19-test \ --batch_size 1 \ --max_input_tokens 11200 \ --column_name "text" \ --dataset_max_samples 1 \ --warmup 0 \ --n_iterations 1 \ --max_new_tokens 5000 \ --kv_cache_on_host
kv_cache_on_host
:kv_cache_on_host
:Stats: ---------------------------------------------------------------------- Throughput (including tokenization) = 1.2790787964372536 tokens/second Total runtime for dataset: 3909.073683977127 Memory allocated = 90.72 GB Max memory allocated = 91.63 GB Total memory available = 94.62 GB Graph compilation duration = 3907.185397926951 seconds ----------------------------------------------------------------------
kv_cache_on_host
:--max_input_tokens 11200 --max_new_tokens 10000
Stats: ---------------------------------------------------------------------- Throughput (including tokenization) = 1.2790787964372536 tokens/second Total runtime for dataset: 3909.073683977127 Memory allocated = 90.72 GB Max memory allocated = 91.63 GB Total memory available = 94.62 GB Graph compilation duration = 3907.185397926951 seconds ----------------------------------------------------------------------